import numpy as np
import torch


def save_ckpt(save_name, model, model_ema, optimizer=None, scheduler=None, epoch=None):
    """Write a training checkpoint to ``save_name`` using ``torch.save``.

    The saved object is a dict with the keys ``"model"``, ``"model_ema"``,
    ``"optimizer"``, ``"scheduler"`` and ``"epoch"``.

    Args:
        save_name: Destination passed straight to ``torch.save`` (a path or
            file-like object).
        model: Object whose ``state_dict()`` is stored under ``"model"``.
        model_ema: Object whose ``state_dict()`` is stored under
            ``"model_ema"``. It must not be None, because ``state_dict()``
            is called on it unconditionally.
        optimizer: Stored as-is under ``"optimizer"``.
        scheduler: Stored as-is under ``"scheduler"``.
        epoch: Stored as-is under ``"epoch"``.

    Returns:
        None.

    NOTE(review): ``optimizer`` and ``scheduler`` are pickled as whole
    objects, not via ``.state_dict()``. Such checkpoints cannot be read
    with ``torch.load(..., weights_only=True)``, and schedulers that hold
    lambdas fail to pickle. Switching to state dicts would change the file
    format, so confirm what the loader expects first.
    """
    # Capture the two weight snapshots in the same order as before:
    # main model first, then the EMA copy.
    checkpoint = dict(
        model=model.state_dict(),
        model_ema=model_ema.state_dict(),
        optimizer=optimizer,
        scheduler=scheduler,
        epoch=epoch,
    )
    torch.save(checkpoint, save_name)


